{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "N6ZDpd9XzFeN"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "KUu4vOt5zI9d"
      },
      "outputs": [],
      "source": [
        "# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CxmDMK4yupqg"
      },
      "source": [
        "# TF-Hub generative image model\n",
        "\n",
        "\u003ctable align=\"left\"\u003e\u003ctd\u003e\n",
        "  \u003ca target=\"_blank\"  href=\"https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/tf_hub_generative_image_module.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\n",
        "  \u003c/a\u003e\n",
        "\u003c/td\u003e\u003ctd\u003e\n",
        "  \u003ca target=\"_blank\"  href=\"https://github.com/tensorflow/hub/blob/master/examples/colab/tf_hub_generative_image_module.ipynb\"\u003e\n",
        "    \u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "\u003c/td\u003e\u003c/table\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Sy553YSVmYiK"
      },
      "source": [
        "This Colab demonstrates the use of a TF-Hub module based on a generative adversarial network (GAN). The module maps N-dimensional vectors, which make up the latent space, to RGB images.\n",
        "\n",
        "Two examples are provided:\n",
        "* **Mapping** from latent space to images, and\n",
        "* Given a target image, **using gradient descent to find** a latent vector that generates an image similar to the target image."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "v4XGxDrCkeip"
      },
      "source": [
        "## Optional prerequisites\n",
        "\n",
        "* Familiarity with [low-level TensorFlow concepts](https://www.tensorflow.org/guide/low_level_intro).\n",
        "* [Generative Adversarial Network](https://en.wikipedia.org/wiki/Generative_adversarial_network) on Wikipedia.\n",
        "* Paper on Progressive GANs: [Progressive Growing of GANs for Improved Quality, Stability, and Variation](https://arxiv.org/abs/1710.10196)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "KNM3kA0arrUu"
      },
      "outputs": [],
      "source": [
        "# Install TensorFlow (version 1.7 or newer).\n",
        "!pip -q install --quiet \"tensorflow\u003e=1.7\"\n",
        "# Install TF-Hub.\n",
        "!pip -q install tensorflow-hub\n",
        "# Install imageio for creating animations.\n",
        "!pip -q install imageio\n",
        "# Install scikit-image for resizing uploaded images.\n",
        "!pip -q install scikit-image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "6cPY9Ou4sWs_"
      },
      "outputs": [],
      "source": [
        "#@title Imports and function definitions\n",
        "\n",
        "from absl import logging\n",
        "\n",
        "import imageio\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "import time\n",
        "\n",
        "try:\n",
        "  from google.colab import files\n",
        "except ImportError:\n",
        "  pass\n",
        "\n",
        "from IPython import display\n",
        "from skimage import transform\n",
        "\n",
        "# The latent dimension is hard-coded because the module is known in advance.\n",
        "# For an arbitrary module, it could be read from module.get_input_info_dict().\n",
        "latent_dim = 512\n",
        "\n",
        "\n",
        "# Interpolates between two non-zero vectors that do not both lie on a line\n",
        "# through the origin. First rescales v2 to have the same norm as v1. Then\n",
        "# interpolates linearly between the two vectors and rescales each intermediate\n",
        "# point back onto the hypersphere whose radius is the norm of v1.\n",
        "def interpolate_hypersphere(v1, v2, num_steps):\n",
        "  v1_norm = tf.norm(v1)\n",
        "  v2_norm = tf.norm(v2)\n",
        "  v2_normalized = v2 * (v1_norm / v2_norm)\n",
        "\n",
        "  vectors = []\n",
        "  for step in range(num_steps):\n",
        "    interpolated = v1 + (v2_normalized - v1) * step / (num_steps - 1)\n",
        "    interpolated_norm = tf.norm(interpolated)\n",
        "    interpolated_normalized = interpolated * (v1_norm / interpolated_norm)\n",
        "    vectors.append(interpolated_normalized)\n",
        "  return tf.stack(vectors)\n",
        "\n",
        "\n",
        "# Given a set of images, show an animation.\n",
        "def animate(images):\n",
        "  converted_images = np.clip(images * 255, 0, 255).astype(np.uint8)\n",
        "  imageio.mimsave('./animation.gif', converted_images)\n",
        "  with open('./animation.gif', 'rb') as f:\n",
        "    display.display(display.Image(data=f.read(), height=300))\n",
        "\n",
        "\n",
        "# Simple way to display an image.\n",
        "def display_image(image):\n",
        "  plt.figure()\n",
        "  plt.axis(\"off\")\n",
        "  plt.imshow(image)\n",
        "\n",
        "\n",
        "# Display multiple images in the same figure.\n",
        "def display_images(images, captions=None):\n",
        "  num_horizontally = 5\n",
        "  # Round the row count up so that a final, partially filled row is kept.\n",
        "  num_rows = -(-len(images) // num_horizontally)\n",
        "  # squeeze=False keeps axes two-dimensional even when there is one row.\n",
        "  f, axes = plt.subplots(\n",
        "      num_rows, num_horizontally, figsize=(20, 20), squeeze=False)\n",
        "  for ax in axes.flat:\n",
        "    ax.axis(\"off\")\n",
        "  for i in range(len(images)):\n",
        "    ax = axes[i // num_horizontally, i % num_horizontally]\n",
        "    if captions is not None:\n",
        "      ax.text(0, -3, captions[i])\n",
        "    ax.imshow(images[i])\n",
        "  f.tight_layout()\n",
        "\n",
        "logging.set_verbosity(logging.ERROR)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "f5EESfBvukYI"
      },
      "source": [
        "## Latent space interpolation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "nJb9gFmRvynZ"
      },
      "source": [
        "### Random vectors\n",
        "\n",
        "This example interpolates in latent space between two randomly initialized vectors. It uses the TF-Hub module [progan-128](https://tfhub.dev/google/progan-128/1), which contains a pre-trained Progressive GAN."
      ]
    },
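    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The function `interpolate_hypersphere` from the imports cell can be illustrated without TensorFlow. The following is a minimal NumPy sketch, not the notebook's implementation: `interpolate_hypersphere_np` is a hypothetical name, and the sketch assumes two non-zero vectors that do not point in exactly opposite directions.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def interpolate_hypersphere_np(v1, v2, num_steps):\n",
        "  # Hypothetical NumPy counterpart of interpolate_hypersphere.\n",
        "  v1_norm = np.linalg.norm(v1)\n",
        "  # Rescale v2 so that both endpoints lie on the sphere of radius |v1|.\n",
        "  v2_scaled = v2 * (v1_norm / np.linalg.norm(v2))\n",
        "  # Fractions 0, ..., 1 as a column so they broadcast over the vector axis.\n",
        "  t = np.linspace(0.0, 1.0, num_steps)[:, None]\n",
        "  # Straight-line interpolation, then rescaling back onto the sphere.\n",
        "  points = v1 + (v2_scaled - v1) * t\n",
        "  norms = np.linalg.norm(points, axis=1, keepdims=True)\n",
        "  return points * (v1_norm / norms)\n",
        "\n",
        "\n",
        "rng = np.random.RandomState(0)\n",
        "path = interpolate_hypersphere_np(rng.randn(512), rng.randn(512), 25)\n",
        "print(path.shape)  # (25, 512)\n",
        "```\n",
        "\n",
        "Every row of `path` has the same norm as the first vector, so the interpolated points stay at the distance from the origin that the generator typically sees for its inputs."
      ]
    },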
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "fZ0O5_5Jhwio"
      },
      "outputs": [],
      "source": [
        "def interpolate_between_vectors():\n",
        "  with tf.Graph().as_default():\n",
        "    module = hub.Module(\"https://tfhub.dev/google/progan-128/1\")\n",
        "\n",
        "    # Change the seed to get different random vectors.\n",
        "    v1 = tf.random_normal([latent_dim], seed=3)\n",
        "    v2 = tf.random_normal([latent_dim], seed=1)\n",
        "    \n",
        "    # Creates a tensor with 25 steps of interpolation between v1 and v2.\n",
        "    vectors = interpolate_hypersphere(v1, v2, 25)\n",
        "\n",
        "    # Uses module to generate images from the latent space.\n",
        "    interpolated_images = module(vectors)\n",
        "\n",
        "    with tf.Session() as session:\n",
        "      session.run(tf.global_variables_initializer())\n",
        "      interpolated_images_out = session.run(interpolated_images)\n",
        "\n",
        "    animate(interpolated_images_out)\n",
        "\n",
        "interpolate_between_vectors()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "L9-uXoTHuXQC"
      },
      "source": [
        "## Finding closest vector in latent space\n",
        "First, fix a target image. You can either use an image generated by the module or upload your own."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "phT4W66pMmko"
      },
      "outputs": [],
      "source": [
        "image_from_module_space = True  # @param { isTemplate:true, type:\"boolean\" }\n",
        "\n",
        "def get_module_space_image():\n",
        "  with tf.Graph().as_default():\n",
        "    module = hub.Module(\"https://tfhub.dev/google/progan-128/1\")\n",
        "    vector = tf.random_normal([1, latent_dim], seed=4)\n",
        "    images = module(vector)\n",
        "\n",
        "    with tf.Session() as session:\n",
        "      session.run(tf.global_variables_initializer())\n",
        "      image_out = session.run(images)[0]\n",
        "  return image_out\n",
        "\n",
        "def upload_image():\n",
        "  uploaded = files.upload()\n",
        "  image = imageio.imread(uploaded[list(uploaded.keys())[0]])\n",
        "  return transform.resize(image, [128, 128])\n",
        "\n",
        "if image_from_module_space:\n",
        "  target_image = get_module_space_image()\n",
        "else:\n",
        "  target_image = upload_image()\n",
        "display_image(target_image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rBIt3Q4qvhuq"
      },
      "source": [
        "After defining a loss function between the target image and the image generated by a latent space variable, we can use gradient descent to find variable values that minimize the loss."
      ]
    },
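    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The idea can be tried on a toy problem before running it on the GAN. The following is a minimal NumPy sketch under stated assumptions: the generator is replaced by a hypothetical fixed linear map (`weights`), the loss is quadratic rather than the absolute difference, and plain gradient descent with a hand-derived gradient stands in for the Adam optimizer.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.RandomState(0)\n",
        "# Hypothetical stand-in for a generator: a fixed linear map from an\n",
        "# 8-dimensional latent space to a 32-dimensional output.\n",
        "weights = rng.randn(32, 8)\n",
        "# A target that the toy generator is able to produce exactly.\n",
        "target = weights.dot(rng.randn(8))\n",
        "\n",
        "z = np.zeros(8)  # The latent variable being optimized.\n",
        "# Step size derived from the largest curvature of the quadratic loss.\n",
        "learning_rate = 1.0 / np.linalg.norm(weights, 2) ** 2\n",
        "losses = []\n",
        "for _ in range(200):\n",
        "  residual = weights.dot(z) - target\n",
        "  losses.append(0.5 * residual.dot(residual))\n",
        "  # Gradient of 0.5 * |weights z - target|^2 with respect to z.\n",
        "  z -= learning_rate * weights.T.dot(residual)\n",
        "\n",
        "print(losses[0], losses[-1])\n",
        "```\n",
        "\n",
        "The loss shrinks over the iterations because the target lies in the range of the toy generator. With a real GAN the loss surface is not convex, so convergence is not guaranteed in the same way."
      ]
    },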
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "q_4Z7tnyg-ZY"
      },
      "outputs": [],
      "source": [
        "def find_closest_latent_vector(num_optimization_steps):\n",
        "  images = []\n",
        "  losses = []\n",
        "  with tf.Graph().as_default():\n",
        "    module = hub.Module(\"https://tfhub.dev/google/progan-128/1\")\n",
        "\n",
        "    initial_vector = tf.random_normal([1, latent_dim], seed=5)\n",
        "\n",
        "    vector = tf.get_variable(\"vector\", initializer=initial_vector)\n",
        "    image = module(vector)\n",
        "\n",
        "    target_image_difference = tf.reduce_sum(\n",
        "        tf.losses.absolute_difference(image[0], target_image[:,:,:3]))\n",
        "\n",
        "    # The latent vectors were sampled from a standard normal distribution, so\n",
        "    # their length is typically close to sqrt(latent_dim). We can get more\n",
        "    # realistic images if we penalize the distance between the length of the\n",
        "    # latent vector and that typical length.\n",
        "    regularizer = tf.abs(tf.norm(vector) - np.sqrt(latent_dim))\n",
        "    \n",
        "    loss = target_image_difference + regularizer\n",
        "    \n",
        "    optimizer = tf.train.AdamOptimizer(learning_rate=0.3)\n",
        "    train = optimizer.minimize(loss)\n",
        "\n",
        "    with tf.Session() as session:\n",
        "      session.run(tf.global_variables_initializer())\n",
        "      for _ in range(num_optimization_steps):\n",
        "        _, loss_out, im_out = session.run([train, loss, image])\n",
        "        images.append(im_out[0])\n",
        "        losses.append(loss_out)\n",
        "        print(loss_out)\n",
        "    return images, losses\n",
        "\n",
        "\n",
        "result = find_closest_latent_vector(num_optimization_steps=40)\n",
        "display_images(result[0], [(\"Loss: %.2f\" % loss) for loss in result[1]])"
      ]
    },
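    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The regularizer in the cell above pulls the length of the latent vector toward `sqrt(latent_dim)`. The following is a small NumPy sketch of why that value is a sensible target, assuming latent vectors drawn from a standard normal distribution; the sample count of 10000 is an arbitrary choice for illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "latent_dim = 512\n",
        "rng = np.random.RandomState(0)\n",
        "# Draw 10000 latent vectors from a standard normal distribution.\n",
        "samples = rng.randn(10000, latent_dim)\n",
        "# Euclidean length of each sampled vector.\n",
        "lengths = np.linalg.norm(samples, axis=1)\n",
        "\n",
        "print('sqrt(latent_dim):', np.sqrt(latent_dim))\n",
        "print('mean length:', lengths.mean())\n",
        "print('std of length:', lengths.std())\n",
        "```\n",
        "\n",
        "The mean length should come out close to `sqrt(512)`, roughly 22.6, with a spread that is small relative to the mean. A latent vector whose length is far from this value is unlike anything the generator saw during training."
      ]
    },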
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tDt15dLsJwMy"
      },
      "source": [
        "### Playing with the above example\n",
        "If the image is from the module space, the descent is quick and converges to a reasonable sample. Try descending to an image that is **not from the module space**. The descent will only converge if the image is reasonably close to the space of training images.\n",
        "\n",
        "How can you make the descent faster and reach a more realistic image? You can try:\n",
        "* using a different loss on the image difference, e.g. quadratic,\n",
        "* using a different regularizer on the latent vector,\n",
        "* initializing from a random vector in multiple runs,\n",
        "* etc.\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "N6ZDpd9XzFeN"
      ],
      "default_view": {},
      "name": "TF-Hub generative image module",
      "provenance": [],
      "version": "0.3.2",
      "views": {}
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
